#!/usr/bin/env python3

#
# Run the test, compare results against the benchmark
#
# requires: netcdf

from __future__ import print_function

try:
  from builtins import str
except ImportError:
  # No "builtins" module available: keep using the interpreter's own str.
  # Only ImportError is ignored, so any other failure is still reported.
  pass

from boututils.run_wrapper import shell, shell_safe, launch_safe
from boutdata.collect import collect
import numpy as np
from sys import stdout, exit



# Build the test executable; the build output is redirected to make.log
print("Making I/O test")
shell_safe("make > make.log")

# Names of every variable that is compared against the benchmark
vars = [
    'ivar', 'rvar', 'bvar',
    'f2d', 'f3d', 'fperp', 'fperp2',
    'ivar_evol', 'rvar_evol', 'bvar_evol',
    'v2d_evol_x', 'v2d_evol_y', 'v2d_evol_z',
    'fperp2_evol',
]

# Subset of the names above that are field quantities rather than scalars
field_vars = [
    'f2d', 'f3d', 'fperp', 'fperp2',
    'v2d_evol_x', 'v2d_evol_y', 'v2d_evol_z',
    'fperp2_evol',
]

# Largest absolute difference from the benchmark that still counts as a pass
tol = 1e-10

# Load the reference value of each variable from data/benchmark.out.*
print("Reading benchmark data")
bmk = {
    name: collect(name, path="data", prefix="benchmark.out", info=False)
    for name in vars
}

print("Running I/O test")
# Overall result; set to False as soon as any individual check fails
success = True
for nproc in [1,2,4]:
    for split in [None, "NXPE", "NYPE"]:
        # With no explicit split option a single run is enough; otherwise
        # try every split value that does not exceed the processor count
        if split is not None:
            npe_max = nproc
        else:
            npe_max = 1
        for np_split in [i for i in [1,2,4] if i<=npe_max]:

            # Command-line option selecting the processor splitting, if any
            if split is not None:
                extra_args = " "+split+"="+str(np_split)
            else:
                extra_args = ""

            cmd = "./test_io" + extra_args

            # On some machines need to delete dmp files first
            # or data isn't written correctly.
            # -f keeps rm quiet when there are no old dump files to remove.
            shell("rm -f data/BOUT.dmp.*.nc")

            # Run test case
            print("   %d processor...." % (nproc) + extra_args)
            s, out = launch_safe(cmd, nproc=nproc, pipe=True)
            # NOTE: the log name depends only on nproc, so every run with the
            # same processor count overwrites the log of the previous one.
            with open("run.log."+str(nproc), "w") as f:
                f.write(out)

            # Check processor splitting: the value read back from the output
            # must equal the value requested on the command line
            if split is not None:
                stdout.write("      Checking "+split+" ... ")
                actual_split = collect(split, path="data", info=False)
                if actual_split != np_split:
                    print("Fail, wrong "+split+" expecting %i, got %i" % (np_split, actual_split))
                    success = False
                else:
                    print("Pass")

            # Collect output data
            for v in vars:
                stdout.write("      Checking variable "+v+" ... ")
                result = collect(v, path="data", info=False)

                # Compare benchmark and output
                if np.shape(bmk[v]) != np.shape(result):
                    print("Fail, wrong shape")
                    success = False
                    continue

                diff =  np.max(np.abs(bmk[v] - result))
                # Written as "not <=" so that a NaN difference is a failure:
                # "NaN > tol" is False and would be reported as a pass.
                if not diff <= tol:
                    print("Fail, maximum difference = "+str(diff))
                    success = False
                    continue

                if v in field_vars:
                    # Check cell location
                    if "cell_location" not in result.attributes:
                        print("Fail: {0} has no cell_location attribute".format(v))
                        success = False
                        continue

                    if result.attributes["cell_location"] != "CELL_CENTRE":
                        print("Fail: Expecting cell_location == CELL_CENTRE, but got {0}".format(result.attributes["cell_location"]))
                        success = False
                        continue

                print("Pass")

# Report the overall outcome; the exit status is non-zero if any check failed
if not success:
    print(" => Some failed tests")
    exit(1)

print(" => All I/O tests passed")
exit(0)
